#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

#requires: all_tests

from __future__ import division
from __future__ import print_function
from builtins import zip
from builtins import str

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boututils.datafile import DataFile
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join, isfile

import pickle

from sys import stdout

## Tokamak geometry object used to write the grid files
## Note: This must be the same as in the mms.py code
from boutdata.mms import SimpleTokamak
shape = SimpleTokamak()

# Build the test executable; the build output is redirected to make.log
print("Making MMS tokamak geometry test")
shell_safe("make > make.log")

# Resolutions to test, and the number of MPI processes for each one.
# The two lists are consumed pairwise, so the surplus entries in nprocs
# only come into play if larger NX values (e.g. 64, 128, 256) are appended.
nxlist = [4, 8, 16, 32]
nprocs = [2, 4, 4, 8, 8, 8, 8]

# Overall result; cleared as soon as one convergence check fails
success = True

# Names of the tested quantities, with the plot marker and legend label
# belonging to each (same ordering in all three lists)
varlist = ["advect", "delp2", "laplacepar"]
markers = ['bo', 'r^', "gs"]
labels = [r'$\left[\phi, f\right]$', r'$\nabla_\perp^2$', r'$\nabla_{||}^2$']

# Per-quantity error histories, one entry appended per resolution
error_2 = {name: [] for name in varlist}    # The L2 error (RMS)
error_inf = {name: [] for name in varlist}  # The maximum error

# Mesh spacing (1/nx) for each resolution that has been run
dx = []

# Run the test once per resolution, record the error norms of every
# quantity, and check the convergence order between successive resolutions.
for nx, nproc in zip(nxlist, nprocs):
    # Generate a new mesh file, unless one for this resolution is present
    filename = "grid%d.nc" % nx

    if isfile(filename):
        print("Grid file '%s' already exists" % filename)
    else:
        print("Creating grid file '%s'" % filename)
        grid = DataFile(filename, create=True)
        try:
            shape.write(nx, nx, grid)
        finally:
            # Release the file handle even if writing the geometry fails
            grid.close()

    # Command-line options passed to the executable.
    # (A time step override could be appended: " solver:timestep="+str(0.001/nx))
    args = " MZ="+str(nx)+" grid="+filename

    print("  Running with " + args)

    # Delete old data so that only output from this run is collected
    shell("rm data/BOUT.dmp.*.nc")

    # Command to run
    cmd = "./tokamak "+args
    # Launch using MPI; only the captured output is needed here
    _, out = launch_safe(cmd, nproc=nproc, pipe=True)

    # Save output to a per-resolution log file
    with open("run.log."+str(nx), "w") as logfile:
        logfile.write(out)

    dx.append(1./nx)
    for var in varlist:
        # Collect the error field written by the test executable
        # NOTE(review): tind=[1,1] presumably selects the single output at
        # time index 1 -- confirm against the collect documentation
        E = collect("E_"+var, tind=[1,1], info=False, path="data")
        # Discard the first and last two points along the second axis
        # (presumably the X guard cells -- TODO confirm)
        E = E[:,2:-2, :,:]

        # Average error over domain
        l2 = sqrt(mean(E**2))
        linf = max(abs( E ))

        error_2[var].append( l2 )
        error_inf[var].append( linf )

        print("%s : l-2 %f l-inf %f" % (var, l2, linf))

    if len(dx) > 1:
        # Calculate convergence order from the two most recent resolutions:
        # the slope of log(L2 error) against log(mesh spacing)
        for var in varlist:
            order = log(error_2[var][-1] / error_2[var][-2]) / log(dx[-1] / dx[-2])
            stdout.write("%s Convergence order = %f" % (var, order))

            if 1.8 < order < 2.2: # Should be second order accurate
                print("............ PASS")
            else:
                success = False
                print("............ FAIL")


# Save data: the resolutions and both error dictionaries are written as
# three consecutive pickles, to be read back in the same order
with open("tokamak.pkl", "wb") as output:
    for record in (nxlist, error_2, error_inf):
        pickle.dump(record, output)


    
# Plot the error norms against mesh spacing and save the figure to norm.pdf.
# Plotting is best-effort: a missing or broken matplotlib must not change the
# test result, so any failure is reported and otherwise ignored.
try:
    import matplotlib.pyplot as plt

    plt.figure()

    # One pair of curves per quantity: solid line for the L2 error (shown in
    # the legend), dashed line with the same marker for the maximum error
    for var, mark, label in zip(varlist, markers, labels):
        plt.plot(dx, error_2[var], '-'+mark, label=label)
        plt.plot(dx, error_inf[var], '--'+mark)

    plt.legend(loc="upper left")
    plt.grid()

    plt.yscale('log')
    plt.xscale('log')

    plt.xlabel(r'Mesh spacing $\delta x$')
    plt.ylabel("Error norm")

    plt.savefig("norm.pdf")

    #plt.show()
    plt.close()
except Exception as err:
    # Catch Exception rather than everything, so that Ctrl-C and SystemExit
    # still propagate, and say why no plot was produced
    print("Skipping error plot: %s" % err)

# Report the overall result through the process exit status:
# 0 if every convergence check passed, 1 otherwise.
# SystemExit is raised directly because the exit() helper is installed by
# the site module for interactive use and is not guaranteed to exist
# (for example when the interpreter is started with -S).
if success:
    print(" => All tests passed")
    raise SystemExit(0)
else:
    print(" => Some failed tests")
    raise SystemExit(1)
